{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Segment Anything Image Segmentation examples using ONNX\n",
    "#\n",
    "# Before running this code, make sure you have run the export_onnx.ipynb notebook\n",
    "# to export both the SAM image encoder and the SAM mask decoder to vit_b_encoder.onnx\n",
    "# and vit_b_decoder.onnx, respectively."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Install dependencies with the %pip magic so they land in the environment of the\n",
    "# running kernel (a `!pip` shell call may resolve to a different Python).\n",
    "# Pillow >= 9.1 is required because this notebook uses Image.Resampling.\n",
    "%pip install onnxruntime \"Pillow>=9.1\" numpy"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "outputs": [],
   "source": [
    "import onnxruntime as ort\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "from copy import deepcopy"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "outputs": [
    {
     "data": {
      "text/plain": "(612, 415)"
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# LOAD IMAGE\n",
    "# Read the sample picture and convert it to RGB; the cell displays (width, height)\n",
    "with Image.open(\"cat_dog.jpg\") as image_file:\n",
    "    img = image_file.convert(\"RGB\")\n",
    "img.size"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "outputs": [
    {
     "data": {
      "text/plain": "(1024, 694)"
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 1. PREPROCESS IMAGE FOR ENCODER\n",
    "\n",
    "# Resize the image so that its longest side is 1024 px, preserving aspect ratio.\n",
    "# Both sides are rounded to the nearest pixel (int(x + 0.5)) instead of truncated,\n",
    "# which is how SAM's reference ResizeLongestSide transform computes the target size.\n",
    "# NOTE: this cell rebinds `img`; re-run the LOAD IMAGE cell before running it again,\n",
    "# otherwise orig_width/orig_height would be read from the already resized image.\n",
    "orig_width, orig_height = img.size\n",
    "scale = 1024 / max(orig_width, orig_height)\n",
    "resized_width = int(orig_width * scale + 0.5)\n",
    "resized_height = int(orig_height * scale + 0.5)\n",
    "\n",
    "img = img.resize((resized_width, resized_height), Image.Resampling.BILINEAR)\n",
    "\n",
    "img.size"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "outputs": [
    {
     "data": {
      "text/plain": "(1, 3, 694, 1024)"
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Prepare input tensor from image (array layout is Height,Width,Channels)\n",
    "input_tensor = np.array(img)\n",
    "\n",
    "# Normalize every channel: subtract the per-channel mean, divide by the per-channel std\n",
    "mean = np.array([123.675, 116.28, 103.53])\n",
    "std = np.array([[58.395, 57.12, 57.375]])\n",
    "input_tensor = (input_tensor - mean) / std\n",
    "\n",
    "# Reorder axes to (Batch,Channels,Height,Width) and cast to float32\n",
    "input_tensor = input_tensor.transpose(2, 0, 1)[np.newaxis, ...].astype(np.float32)\n",
    "\n",
    "input_tensor.shape"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "outputs": [
    {
     "data": {
      "text/plain": "(1, 3, 1024, 1024)"
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Make image square 1024x1024 by zero-padding the bottom and right edges.\n",
    "# The pad widths are taken from the tensor's current (Batch,Channels,Height,Width)\n",
    "# shape, so re-running this cell on an already padded tensor adds nothing\n",
    "# (padding by 1024 - resized_height again would grow the tensor past 1024).\n",
    "pad_bottom = 1024 - input_tensor.shape[2]\n",
    "pad_right = 1024 - input_tensor.shape[3]\n",
    "input_tensor = np.pad(input_tensor, ((0, 0), (0, 0), (0, pad_bottom), (0, pad_right)))\n",
    "\n",
    "input_tensor.shape"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "outputs": [
    {
     "data": {
      "text/plain": "(1, 256, 64, 64)"
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 2. GET IMAGE EMBEDDINGS USING IMAGE ENCODER\n",
    "# Load the exported encoder and feed the padded tensor to its \"images\" input;\n",
    "# the first returned array is kept as the image embeddings\n",
    "encoder = ort.InferenceSession(\"vit_b_encoder.onnx\")\n",
    "encoder_feed = {\"images\": input_tensor}\n",
    "outputs = encoder.run(None, encoder_feed)\n",
    "embeddings = outputs[0]\n",
    "\n",
    "embeddings.shape"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# 3. DECODE MASKS FROM IMAGE EMBEDDINGS\n",
    "\n",
    "# 3.1 OPTION 1: Use single point as a prompt"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "outputs": [
    {
     "data": {
      "text/plain": "array([[[537.098 , 384.6265],\n        [  0.    ,   0.    ]]], dtype=float32)"
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# ENCODE PROMPT (single point)\n",
    "# One (x, y) point given in original-image pixels, with label 1\n",
    "# (NOTE: in SAM's prompt convention 1 marks a foreground point - confirm against the exported decoder)\n",
    "input_point = np.array([[321,230]])\n",
    "input_label = np.array([1])\n",
    "\n",
    "# Append a padding point (0, 0) with label -1 and add a leading batch axis\n",
    "padding_point = np.array([[0.0, 0.0]])\n",
    "padding_label = np.array([-1])\n",
    "onnx_coord = np.concatenate([input_point, padding_point], axis=0)[None, :, :]\n",
    "onnx_label = np.concatenate([input_label, padding_label])[None, :].astype(np.float32)\n",
    "\n",
    "# Map coordinates from the original image to the resized image:\n",
    "# column 0 (x) scales with the width ratio, column 1 (y) with the height ratio\n",
    "coords = onnx_coord.astype(float)\n",
    "coords[..., 0] *= resized_width / orig_width\n",
    "coords[..., 1] *= resized_height / orig_height\n",
    "\n",
    "onnx_coord = coords.astype(\"float32\")\n",
    "onnx_coord"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "outputs": [],
   "source": [
    "# RUN DECODER TO GET MASK\n",
    "# No previous mask is supplied: pass an all-zero mask together with has_mask_input = 0\n",
    "onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)\n",
    "onnx_has_mask_input = np.zeros(1, dtype=np.float32)\n",
    "\n",
    "decoder = ort.InferenceSession(\"vit_b_decoder.onnx\")\n",
    "decoder_inputs = {\n",
    "    \"image_embeddings\": embeddings,\n",
    "    \"point_coords\": onnx_coord,\n",
    "    \"point_labels\": onnx_label,\n",
    "    \"mask_input\": onnx_mask_input,\n",
    "    \"has_mask_input\": onnx_has_mask_input,\n",
    "    \"orig_im_size\": np.array([orig_height, orig_width], dtype=np.float32),\n",
    "}\n",
    "# The decoder returns three arrays; only the first one (the masks) is used here\n",
    "masks, _, _ = decoder.run(None, decoder_inputs)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "outputs": [
    {
     "data": {
      "text/plain": "(415, 612)"
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# POSTPROCESS MASK\n",
    "# Take element [0][0] of the decoder output (a 2-D array) and binarize it:\n",
    "# values above 0 become 255, everything else becomes 0\n",
    "mask = (masks[0][0] > 0).astype(np.uint8) * 255\n",
    "mask.shape"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "outputs": [
    {
     "data": {
      "text/plain": "<PIL.Image.Image image mode=L size=612x415>",
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAmQAAAGfCAAAAADUx5EnAAAKtklEQVR4Ae3d7XKbOBgG0GSn93/L3WSa2BYIkJBkJHT6pwb0xXmfEY7j7X58+EOAAAECBAgQIECAAAECBAgQIECAAAECBAgQIECAAAECBAgQIECAAAECBAgQIECAAAECBAgQIECAAAECBAgQIECAAAECBAgQIECAAAECBAgQIECAAAECBAgQIECAAAECBAgQIECAAAECBAgQIECAAAECBAgQIECAAAECBAgQIECAAAECBAgQIECAAAECBAgQIECAAAECBAgQIECAAAECBAgQIECAAAECBAgQIECAAAECBAgQIECAAAECBAgQIECAAAECBAgQIECAAAECBAgQIECAAAECBAgQIECAAAECBAgQIECAAAECBAgQIECAAAECBAgQIECAAAECBAgQIECAAAECBAgQuKvA511vrPy+/v4MgajUkuCW4G/Gvq6fRfoe4mzfrWWNeJ5BvGovETsZlMcIiAlEQ/ZIyM/VbKZggOze0TWNe/K/cZf+zpUHmUmYOLd9wpADNxGyWPFKM7LovziMzXjrc0KWVt6snKwar06kTXqXVkIWqWQsE7Fzka5ORQSELIISPZWesvSW0Ynud1LI1jVtEJIGQ67X3e0ZIUsuTWpQUtslTzx8QyFLL2FaetJapc96g5ZCVrmIGxnbOF158k6HE7JVYbbzsH1lNYgTLwJC9oLhZRsBIavrarOLeApZBGXzlAht0uxdELI9HdeqCAhZFuPRVnZ0PWuy2zQWstuUst8bEbK82uxvVftX82a6UWshq1jMnYxN/eVYIVuFbD8PO0HaubSaZKoTQjZVua+5WSG7xn2qWf9Mdbc1bvbroRh/oHpabvHaybZkds7/jeUpdm5njJkuCdmpaq9jtp+x+N53auoBO3lcnizad6ge0dlP2MkZ7tNNyApqKVtpeB6XaU5aFQgIWQGermkCQpbmpFWBgJAV4OmaJiBkaU5aFQgI2QrPz4wrksITQlYIqPuxgJAdG2lRKCBkhYC6HwsI2bGRFoUCQlYIqPuxgJAdG2lRKCBkK8DHdytWV5w4JyBk59z0yhAQsgwsTc8JCNk5N70yBIQsA0vTcwJCds5NrwwBIcvA0vScgJCdc9MrQ0DIMrA0PScgZOfc9MoQELIMLE3PCQjZOTe9MgSELANL03MCQnbOTa8MASHLwNL0nICQrdz810orksITQlYIqPuxgJAtjWxkS5HiYyErJjTAkYCQHQm5XiwgZMWEBjgSELIjoSrX536jJ2RVQmSQPQEh29NxrYqAkFVhPBxk6uelkB3mQ4NSASErFUzsP/NWJmSJIdHsvICQLe1a/VMYE29lQrYM2cRhWFLUOhayWpLG2RQQsk0aF2oJCFktSeNsCgjZJo0LtQSEbCHpff8CpMKhkFVATBti3vgKWVpCtCoQELICPF3TBIQszUmrAgEhK8DTNU1AyNKctCoQELIQb94fAUOHqkdCVpVzb7BWX+/Ym7OPa0LWRx1uvQohe1t5530SC9nbQjbvREI2b+3fdudC9jbqeScSsnlr/7Y7F7KQuuXnDNO+8xeyMGSOGggIWQPUrSFn3cqEbCsRzlcTELJqlAbaEhCyLZkW5yd9XgpZizAZMxAQsoDj46PlZxgfH3NuZUK2CJnD+gJCVt90b8QptzIh24uEa1UEhKwKY/ogM25lQpaejzotJ0yZkNWJTsYo86Ws7U/sGfS9NH1LBCZTt5Ndke6/b4nyFXcWnVPIoizNT06VMiEL8zRV8cNbb3ckZO1sd0eeKc1CthsFF2sICFmgONP+Etx40wMha8q7M/hEef6zw+DSpQKPFA7/qZqQXRqkrckfAftu8O9g4KgJ2VadLzsfBOy5ir/jpmzclT/5K77aKHDFGV6G+rbPmXDYWg278JdiVXyZU/OK06YONWi1Bl12alUy23Wesa+7GbJePsLIzOHF
zYf81bqQXZya7On7321XtyRkKxInagsIWW3R5uONt5UJWfNQVJ9guPdlQvaSgWH2iGEW+g9XyF5C5mUbASF7ug60Pwy01C9fIXuGbKRXQ6VsyE+Q26RhqLp9EYxTOjtZm8Qa9UVAyH4xRtvIsr7A8XuT1/wtZNe4TzWrkP2Ue7iNLO+7aJeGWsgu5Z9jciEbuM6j7L5CNnDIRlm6kI1Sqdg6B9nKhCxWvGHOjZEyIRsmUOMuVMh+ajfOL2mCsA2xlQlZUDMHLQSErIXqG8ccYSsTsjcGYtaphOy38oO+Kftdfs9/C1nP1bnJ2oTsUchBt7IB3pQJ2SNko77oP2VC9szWoFtZ/99eFLJnyAb61vzLogd4KWQDFOloib0/MIXstYIemK8a1V4LWUA5asqCm+juQMjCkmSl7PMzq3k400RHlBbFTn5/8yOX3H4xT+XDrutoJ1tUO7Vav+1sZgvAyOGvVeTStKdSNqfQLaVHW85wPW3nyh6968Vl302lDrHMfEM9z6/ZntcqLSJvmPWC8vo3be1xGeHdqtjj/OPFs/PFj82LM/50iL0SsphKJESvzeKXL47Z6wI7ey1k0YLEY/Tv907bYdroFZ1hppNctqq9eAClQS06bY3d4nzaAlvMfDhmx0s7XPt7Gnzl5vMj+X/Rdl3KOq5kx0t7T4YqzyJkEVDvySIoBae237AVDDp6VyGrXcGrng3X7aGHgkJ2SJTb4KqU5a7zfe2F7H3W084kZPVLbytbmArZAqTGoZSFikIWetQ5krLAUcgCjpEP+v3xUsia5MpW9soqZK8aXjcRELImrFcM2u/uKWRN8tDv+6Mmt3swqJAdALlcLiBk5YbrEWxkgYmQBRwOWggIWQtVYwYCQhZwVDq44ge9K+ZM5BKyRKjem3WcsQ8h6z09aevrOWNCllZDrUoE7GQlevomCQhZElPvjbp+WnpcNomPD2MDVjtZwOGghYCQtVDt++nV4o53xxSyXZ6TFz0uAzghCzgctBAQshaq7x6z88ezkLUIROdFb3HLe2MK2Z7O6WvNUjbkP+jSTON0fW7Ssfp7/0elViM/rvRK1/0Ce4U7WtcqCkcddq4virQYenF1Z5yrLvW/wqtkSuddROHscLEChUPHWpydrU0/78nauH7/G6AV/iS8BasyT4Wl7gwxwBJ3Vt/5pXDHyV/sfnG+Rt9vkD9hox6DLLPR3Tcftihmt6nNbW6keV4KJ8iK282qcrPbKUxC2+6JObtfSe53R22DUjz6UdLuWJA73lNxENoOsBOzm1bjprfVNiYVRo8k7b6luO+dVUhC2yGCnKlDW2yjEyBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIdCzwP166lGUgg0MeAAAAAElFTkSuQmCC\n",
      "image/jpeg": "/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/wAALCAGfAmQBAREA/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/9oACAEBAAA/APn+iiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiug8F+E7zxr4otdFs38rzcvNOULrDGoyzED8AM4BYqMjOa+k9K+BHgey0u3t7/TpNQu0TEt09xLGZW7najgKOwHoBkk5J4T4ofA+30zTptc8Jp5draxPNeWk1wTtRQvMW4ZOAHZtzf7vpXg9FFFFbHhjwxqni7XIdI0iDzbiTlmbhIkHV3PZRkfmAASQD7non7Nmmpb7te1y7mnZEOywVYljbHzDc4YuM4wcL06c8eYfEb4Xal8PXtppbuO/0+5cpFcxxMhVgoOHHIUnLbQGOQpPGMVwdFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFfSf7OXh6G28Nah4gltpFu7u4NvFJIgAMKAHKHGcFywbBwTGO617ZRXyp8bfh8vhPxAur6bBHHo+pOdkUMbBbaUAbk9AG5ZQCP4gAAvPldFFFfXfwg+H83gTw1MdQ8v+1tQdZbkIxIiVR8kec4JXLEkd2IyQAT6JXN+P9AbxR4D1nR4lkaee3LQIjKpeVCHjXLcAFlUHOOCeR1r4kooooooooooooooooooooooooo
oooooooooooooooooooooooooooooooooooooooor6/8Agl/ySHQv+3j/ANKJK9Aorj/id4R/4TTwNe6bEu6+i/0my5x++QHC/eA+YFkyTgbs9q+MKKK7j4U+D7jxd45sU+zeZptlKlxfO8QeMIpyEYEgHeV2454JOCFNfY9FRzzw2tvLcXEscMESF5JJGCqigZJJPAAHOa+BKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKK+n/wBnH/knmof9hWT/ANFRV7BRRXifjf4AQ61e3uraFq8kWoXdxJcSw32DEzO+4hWRcoBluobPA45NeOan8K/HOkeV9p8M30nm52/ZFFzjGM58ott698Z5x0NR6b8MfG+q3DQW/hjUkdULk3UJt1xkDhpNoJ56Zz19DX1H8NfBH/CBeEl0qS6+03Uspubl1GEEjKqlU4ztAUDJ5PJ4zgdhRXJ/E7UodK+GXiK4nWRkeye3AQAndKPKU8kcbnBPtnr0r4sooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooor6r/AGfdNmsfhkLiVoyl/ey3EQUnIUBYsNx13RseM8EfQeqUUUUUUUUV5/8AG3/kkOu/9u//AKUR18gUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUV9X/AAB1P7f8L4LbyfL/ALPu5rbduz5mSJd2Mcf63GOfu574HqFFFFFFFFFef/G3/kkOu/8Abv8A+lEdfIFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFfT/7OP8AyTzUP+wrJ/6Kir2CiiiiiiiivP8A42/8kh13/t3/APSiOvkCiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiivr/AODHhv8A4Rz4a2G6TfNqX/Ewkw2VXzFXaBwMfIEyOfm3c4xXoFFFFFFFFFef/G3/AJJDrv8A27/+lEdfIFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFfb/gT/knnhr/sFWv/AKKWugoooooooorL8SaJD4j8Nalo0/lhLy3eIO8YkEbEfK+09SrYYcjkDkV8OX9jcaZqNzYXkfl3VrK8MybgdrqSGGRwcEHpVeiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiivuPwbpTaH4K0TTJbaO2nt7KJJ4k24WXaPM+7wSW3EkdSSec1uUUUUUUUUUV8ifHGCaH4t6u8sUiJMkDxMykB18lFyvqNysMjuCO1ed0UUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUVc0nTZtZ1mx0u3aNZ724jt42kJChnYKCcAnGT6GvvOiiiiiiiiiivlj9oa+t7v4lRwwSb5LTT4oZxtI2OWeQDnr8rqePX1zXk9FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFanhrUodG8VaRqlwsjQWV7DcSLGAWKo4YgZIGcD1FfddFFFFFFFFFFfOH7SOjXieIdJ1zZusZbT7HvUE7JEd3wxxgZD8c5O1uOK8Poooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooor7v0LU/7b8PaZq3k+T9utIrnyt27ZvQNtzgZxnGcCtCiiiiiiiiiuP8Aid4R/wCE08DXumxLuvov9JsucfvkBwv3gPmBZMk4G7PavjCiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiirFhY3
Gp6jbWFnH5l1dSpDCm4Dc7EBRk8DJI61916TpsOjaNY6XbtI0Flbx28bSEFiqKFBOABnA9BVyiiiiiiiiiivJ/iB8DtL8V3lzq+k3P9matNl5FK5gnfB5YDlWY7csM9CdpJJr5cngmtbiW3uIpIZ4nKSRyKVZGBwQQeQQeMVHRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRXefBiCG5+LegpPFHKgeVwrqGAZYXZTz3DAEHsQDX2HRRRRRRRRRRRRXyx8UPhz4gn+KGonRPDt9NZ6hKk0U0MbPGXkA3ln5CfvN5O4jA54XFY//AApL4h/9C9/5O2//AMco/wCFJfEP/oXv/J23/wDjldBon7O3im+8iTVryx0uF93mJuM80eM4+VfkOSB0fgH1GK7ex/Zu8Nx2ca3+s6rPdDO+SAxxI3JxhSrEcY/iPrx0qnqX7NOmy3CtpfiS7toNgDJdWyzsWyeQylABjHGOx5548g8c/D7WvAuqSQX0Ek1gXAt9QSMiKYHJAz0V8KcoTkYOMjBPJ0UUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUV7R+zdY3EnjLVr9Y82sOn+TI+4fK7yIVGOvIjf8AL3FfS9FFFFFFFFFFFFFFFFFFRzwQ3VvLb3EUc0EqFJI5FDK6kYIIPBBHGK+aPjD8JptAuL/xToyWiaGzxtJaxko1szHacKeChbHQ8b8BQq5rxuiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiivs/4Y+Ef+EL8DWWmyrtvpf8ASb3nP75wMr94j5QFTIODtz3rsKKKKKKKKKKKKKKKKKKKKjnghureW3uIo5oJUKSRyKGV1IwQQeCCOMV8mfGXwHb+CfFED6ZB5OkX8W+3TeW8t0AEiZZix6q2Tj7+B0rzeiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiivvPSdSh1nRrHVLdZFgvbeO4jWQAMFdQwBwSM4PqauUUUUUUUUUUUUUUUUUUUUVy/xA8HW/jjwlc6TL8twuZrOQuVEc4UhC2AcryQeDwTjnBHxZPBNa3EtvcRSQzxOUkjkUqyMDggg8gg8YqOiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiivpv4BeOV1jw+fC95LI2oaaheAsGO+2yAPmJPKs23HAC7AM4OPZKKKKKKKKKKKKKKKKKKKKKK+VPjd4Bm8MeJZNdglkn0/WLiSUlkOYJidzISBjBySvfAYY+XJ8rooooooooooooooooooooooooooooooooooooooooooooooooooooq5pWq32h6pb6nplzJbXlu++KVOqn+RBGQQeCCQcg19f+AfiZovj+3kWzElrqECK09lMRuAIGWQj76BjjPB6ZAyM9pRXN+GvH3hfxfcT2+hatHdTwIHkjMbxttJxkB1BIzgEjOMjPUVc8T+J9L8I6HNq+rz+Vbx8Kq8vK56Ig7scH8iSQASPnDU/2g/GN3qMU1itjYWsUpcW6w+Z5qZBCSM3JwBjKbM5PTjHpegftC+FNSRV1iG70efYWYshnizuwFVkG4kjnlAOCM9M+qWN/Z6nZx3lhdwXdrJnZNBIJEbBIOGHBwQR+FWKKKKKKKKKKKKKK5f4jaN/b/w617Tgk8kjWjSxRwDLvJH+8RQMHOWRRgcnPHNfFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFXNK1W+0PVLfU9MuZLa8t33xSp1U/yIIyCDwQSDkGvUIP2ivGUNvFE9po07ogVpZLeQM5A+8dsgGT14AHoBWpN+0prDRWwh8P2KSLEwuGeZ2Ekm3CsoGNihuSpLEjjcD81eR6B4j1fwtqi6lot9JaXYQpvUBgynqGVgQw6HBB5APUCpPE
PirXfFd4LrXNTnvZF+4HICR5AB2oMKudozgDOMnmseius+H3jm+8C+JYb6CWQ2Erql/bgbhNFnnAJA3gElTkYPHQkH678MeJ9L8XaHDq+kT+bbycMrcPE46o47MMj8wQSCCdiiiiiiiiiiiiiiviDxzon/COeOda0lbf7PDBdv5EW/fthY7o+cnPyFTyc+vOa5+iiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiivTPgv45vvDXjGz0dpZJNJ1W4WCW3xu2yv8qSLkjad20Me69iQuPrOiiiiiiiiiiiiivkz482cNt8VLyaK7jne6t4ZZY1xmBggTY3J52or844ccdz5nRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRXpnwGksU+Klmt3DJJO9vMtmyniOXYSWbkceWJB35Ycdx9Z0UUUUUUUUUUUVHPPDa28txcSxwwRIXkkkYKqKBkkk8AAc5r4w+Jev2Pij4h6vrGmNI1nO8axO67S4SNU3Y6gEqSM4OCMgHiuToooooooooooooooooooooooooooooooooooooooooooooooooooooooooooor1D4A2KXfxQgmeOd2tLSaZDEyhUJAjy+eSuHI+XncV7Zr6voooooooooooorh/jBfXGn/AAo1+a1k8uRokhJ2g5SSRI3HPqrMPbPHNfHFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFbHhjxPqnhHXIdX0ifyriPhlblJUPVHHdTgfkCCCAR9l+F/F2i+MdLS/0e9jmBRWlgLASwE5G2RM5U5VvY4yCRzW5RRRRRRWP4h8VaF4Usxda5qcFlG33A5JeTBAO1BlmxuGcA4zk8V5Pqf7SmjxeV/ZPh++us58z7XMlvt6Yxt3579cYwOueOQv/wBozxZcfaUs7DSrSOTeIW8p5JIgc7TkttZhxyVwSOmOK9H+GXxisPEOhyReJ7+x07U7PajzXE8cKXQOcOoJGG4+YAYGQRjO0eiSeJdBh0uHVJdb01NPmfZFdtdIInbnhXzgn5W4B7H0r5w+NXxMXxTfyeHdNFpNo9lcJKl5GWZppVRlYg8DZ85HAOduQ2DivI6KKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKsWN/eaZeR3lhdz2l1HnZNBIY3XIIOGHIyCR+NbH/Cd+MP8Aoa9c/wDBjN/8VXaab+0F43sbdorg6bqDlywlurYqwGB8o8tkGOM9M8nnpj0fSv2jfDN0luup6ZqVjPI+2UoFmiiG7G7dkMRjBOEz1AB795Y/EfwVqFnHdQ+KNKSN84E9ysLjBI5RyGHTuOevSpJ/iD4NtreWd/FWjFI0LsI72N2IAzwqklj7AEntXn+v/tFeHNPdotF0+71Z1cDzGP2eJlK5JUsC+QcDBQd+emfINf8AjH438QO27WJNPg3h1h07MAUhcffB3kHkkFiMn2GOP1LVtS1m4W41TULu+nVAiyXUzSsFyTgFiTjJJx7mqdFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF
FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFf//Z\n"
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# VISUALIZE MASK\n",
    "# Wrap the 0/255 array in an 8-bit grayscale (\"L\") image; the cell output renders it\n",
    "img_mask = Image.fromarray(mask, mode=\"L\")\n",
    "img_mask"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# 3.2 OPTION 2: Use box as a prompt"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "outputs": [
    {
     "data": {
      "text/plain": "array([[[220.86275, 262.5494 ],\n        [428.33987, 543.49396]]], dtype=float32)"
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# ENCODE PROMPT (box)\n",
    "# Box given as two (x, y) corner rows in original-image pixels, tagged with labels 2 and 3\n",
    "# (NOTE: in SAM's prompt convention 2/3 mark the top-left/bottom-right corners - confirm against the exported decoder)\n",
    "input_box = np.array([[132, 157], [256, 325]])\n",
    "input_labels = np.array([2,3])\n",
    "\n",
    "# Add a leading batch axis\n",
    "onnx_coord = input_box[None, :, :]\n",
    "onnx_label = input_labels[None, :].astype(np.float32)\n",
    "\n",
    "# Map coordinates from the original image to the resized image:\n",
    "# column 0 (x) scales with the width ratio, column 1 (y) with the height ratio\n",
    "coords = onnx_coord.astype(float)\n",
    "coords[..., 0] *= resized_width / orig_width\n",
    "coords[..., 1] *= resized_height / orig_height\n",
    "\n",
    "onnx_coord = coords.astype(\"float32\")\n",
    "onnx_coord"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "outputs": [],
   "source": [
    "# RUN DECODER TO GET MASK\n",
    "# No previous mask is supplied: pass an all-zero mask together with has_mask_input = 0\n",
    "onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)\n",
    "onnx_has_mask_input = np.zeros(1, dtype=np.float32)\n",
    "\n",
    "decoder = ort.InferenceSession(\"vit_b_decoder.onnx\")\n",
    "decoder_inputs = {\n",
    "    \"image_embeddings\": embeddings,\n",
    "    \"point_coords\": onnx_coord,\n",
    "    \"point_labels\": onnx_label,\n",
    "    \"mask_input\": onnx_mask_input,\n",
    "    \"has_mask_input\": onnx_has_mask_input,\n",
    "    \"orig_im_size\": np.array([orig_height, orig_width], dtype=np.float32),\n",
    "}\n",
    "# The decoder returns three arrays; only the first one (the masks) is used here\n",
    "masks, _, _ = decoder.run(None, decoder_inputs)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "outputs": [
    {
     "data": {
      "text/plain": "(415, 612)"
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# POSTPROCESS MASK\n",
    "# Take element [0][0] of the decoder output (a 2-D array) and binarize it:\n",
    "# values above 0 become 255, everything else becomes 0\n",
    "mask = (masks[0][0] > 0).astype(np.uint8) * 255\n",
    "mask.shape"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "outputs": [
    {
     "data": {
      "text/plain": "<PIL.Image.Image image mode=L size=612x415>",
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAmQAAAGfCAAAAADUx5EnAAAIg0lEQVR4Ae3dYZuaOBSAUXf//3+2M874GHQkSrjxkpx+2GKBQA5vo51229PJNwIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgMIXA+TzFNMef5H9pp3gpLO/tpXVLeGP/J7yn4pasZQXGYTfTLhXXvF69wa/jXz30sA/rqDeefCU7na6xrQNfPr69duj6QPYGCKSP7KXK5BWQxn5DZo2syKbYfDLv+hFPTvTDXQSyRlZOvtZQbX85lu0PCBwhsgqLxipAH999iMjWMvIV2483VL2BQ0S28uF/rb/q5B3QR+AYkT2tbNGYr5P1aebtqxwkst/KzndvjovG3p67EzoJZP3ZvyWfrHPp9CjzXuYoK1leQXdWFRBZlcgBrQIiaxV0flVAZFUiB7QKDBSZz/2tMUSdP1BkUUTGbRUYJzILWWsLYeePE1kYkYFbBUTWKuj8qsAwkXm3rD7rjx0wTGQfE3ThqoDIqkQOaBUQWaug86sCIqsSOaBVQGStgs6vCiSNbMsfJ6vO1QEfEkga2fsasnzfrNcZOSPbUsyWc3opT36dnJFN/lBGm37KyCxKY2WWMrJtxNLc5hZ/VsbIttay9bx45cmvkDGyzY/k7v/K3DyOE/cVyPiHF9pWpIwz2veZHW60oVayi77lLF2ECX/ety1kP8IJp5Xu0fe7ofFWsoud5axfQvUrJfwpv8dKdpl4wrnVH8iIRwy6kl0eleUsSbEjR/bq38+e5FGMextjR3aymGVId/DILGYi6yFgMeuhvHqN4Veyr9nv9svVVUk7nwrkiywgiYAhn4La8SiQL7LHe2z/EZW1GzaMMEdkDUBObReYJDJLWXsq20eYJDIf/rcn0n5musii1pyocdsfwfgjpIssjFxlYbS1geeJrCZhf5hAtsisN2GP+nMDZ4vscxKuHCYwUWQWybCKKgMni0wIled1yN3JIjukoZuuCMwUmWWyEkPU7pkiizI0bkVgqsgsZZUagnZPFZnfwQyqqDLsXJFVMOyOEZgsMm+YMRmtjzpZZOsY9sYIzBaZpSymo9VRZ4tsFcPOGIHpIrOUxYS0Nup0ka1h2BcjMF9klrKYklZGnS+yFQy7YgREFuNq1EJgwsi8XxbPv8vmhJF1cXWRQkBkBYbNGAGRxbgatRAQWYFhM0ZAZDGuRi0ERFZg2IwRmDEyX8OIaenpqDNG9hTDjhgBkcW4GrUQEFmBYTNGQGQxrkYtBERWYNiMERBZjKtRCwGRFRg2YwREFuNq1EJAZAWGzRgBkcW4GrUQSBaZfza8eDbDbCaLbBhXEykEZozMclkE0GNzxsh6uLpGISCyAsNmjIDIYlyNWghki8znpeLhjLKZLbIOrjru
gLy4RLrIJLB4PkO8SBfZEKomsRDIF5mlbPGARniRL7IRVM1hIZAwsuClLHj4ha4XF4GEkXkyowmIbLQnmnA+GSPzhpYwlJZbyhhZy3ycm1BAZAkfymi3NF9k/rqV7g3PF1l3YhdMGZlP/mOFmTKysYjNZsLIfCjrnf2EkfmXyEXWW8D1wgVyrmTBn/y9YYZ3tbhAzsgWt+jF0QXmjMxS1rXbOSPrSuxik0b2tZSdz9azTv0Hf8TePIs+AWSd/Wa2nCdOupLlfBij3pXIRn2yieY1d2R93pQTPe7P3ErWyHxa+kwPIVfN+zD/XmV+7vfvfRuA8k5/w2TSnpJY+VLS9/19b3x9f/7+z++3vSq7jXgd2ff7C6RWPj+/O5Xt30LYiM8fY9gl9xl4a2XfEy7OPez891HsM8qBkYtUXre6zvd68vX16yM48m2BUZCv0VwAfie1+LFfmdt8f/beXr9N54RXBQZCvlSznM9jZrf9fxz+Kprj3hO4ob93Xsajy19+/tzfQ2SL6T4en3FWA9zTQn2A+dxN4a6ywWd7N/k0L7N+xX8noGVVy1c7XcIwVYHx3W+L2fhzrT5uB0QJ+NOJUbLGJUCAAAECBAgQIECAAAECBAgQIECAAAECBAgQIECAAAECBAgQIECAAAECBAgQIECAAAECBAgQIECAAAECBAgQIECAAAECBAgQIECAAAECBAgQIECAAAECBAgQIECAAAECBAgQIECAAAECBAgQIECAAAECBAgQIECAAAECBAgQIECAAAECBAgQIECAAAECBAgQIECAAAECBAgQIECAAAECBAgQIECAAAECBAgQIECAAAECBAgQIECAAAECBAgQIECAAAECBAgQIECAAAECBAgQIECAAAECBAgQIECAAAECBAgQIECAAAECBAgQIECAAAECBAgQIECAAAECBAgQIECAAAECBAgQIECAAAECBAgQIECAAAECBJoF/gEH93HED9VxDwAAAABJRU5ErkJggg==\n",
      "image/jpeg": "/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/wAALCAGfAmQBAREA/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/9oACAEBAAA/APn+iiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiii
iiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiivdJv2dbtfBQuoNQkfxJsWY2bhFi+6C0O4E/OGzh9204AwAdw8PngmtbiW3uIpIZ4nKSRyKVZGBwQQeQQeMVHRRVixsLzU7yOzsLSe7upM7IYIzI7YBJwo5OACfwrU8S+D9f8AB9xBBr2myWbzoXiJdXVwDg4ZSRkcZGcjI9RWHRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRW54LghuvHXh63uIo5oJdTtkkjkUMrqZVBBB4II4xX3HXjfxt+GE3iS3XxFoNpG+qW6EXcManzLuMAYI7F1APGMsDjJ2qp+cNN0nUtZuGt9L0+7vp1Qu0drC0rBcgZIUE4yQM+4qnWx4Y8Map4u1yHSNIg824k5Zm4SJB1dz2UZH5gAEkA/V/wAO/hlpfw+s5TFJ9t1OfImvnj2EpnIRVydq8Ank5PJPCgeYftL31vJqPh2wWTN1DFPNIm0/KjlApz05Mb/l7ivB6KKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKK6DwJ/yUPw1/2FbX/wBGrX2/RRRRRXyh8fNZs9X+JTRWb+Z/Z9olnM4IKmQM7sAQT03hTnBDKwxxXl9FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFdB4E/5KH4a/7Ctr/6NWvt+iiiiivgi/vrjU9Rub+8k8y6upXmmfaBudiSxwOBkk9Kr0UUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUV7h+zt4R+2axd+K7lf3Njm2tOeszL87cN/CjYwQQfMyOVr6Pooooor5Q+Mfw8vPDfii81XTtJ8rw9c7ZY3tlLR27EKrK/8Azzy5yB93DAL0IHl9FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFfW/wKsbe0+FGmzQR7JLuWeac7id7iRoweenyoo49PXNekUUUUUUVw/if4SeDvFUs1zdaZ9lvpuWu7JvKcnduLEcozEk5ZlJOevAx4h4x+AviDw5Z/bdJn/ty3XaJI4LdlnUkkZEYLblHy8g556YBNeT0UUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUV9p/DHTYdK+GXh23gaRkeyS4JcgndKPNYcAcbnIHtjr1rrKKKKKKKKK8f8AjD8KbPXNHuNe0DT/AC9bt8ySQ2kYH21S2XyoxmQZZtwyzcrhiVx8wUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUVueDtAbxT4x0rRVWQpdXCrL5bKrLEPmkYFuMhAx79Oh6V9x0UUUUUUUUUV8kfHPRv7I+KF7KqQRw6hFHeRpCMYyNjFhgfMXR2PXO7Ock15vRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRXoHwS/wCSvaF/28f+k8lfX9FFFFFFFFFFed/GLwMvjHwdJPbxSPq2mI89mELHeODJHtAO4sq8DGdwXkAnPyJRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRXoHwS/5K9oX/bx/wCk8lfX9FFFFFFFFFFFfHHxb8MJ4V+IuoWttB5Njc4u7VRtwEfqFC42qHDqBgYCjr1PD0UUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUV6B8Ev+SvaF/28f8
ApPJX1/RRRRRRRRRRRXzZ+0nps0XirRtUZo/IuLI26KCdwaNyzE8YxiVcc9j07+J0UUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUV6B8Ev+SvaF/wBvH/pPJX1/RRRRRRRRRRRXz/8AtNf8yt/29/8AtGvAKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKK2PCd9b6Z4y0O/vJPLtbXULeaZ9pO1FkUscDk4APSvueiiiiiiiiiiivmT9o3VWuvGun6YtzHJBZWQcxLtJilkY7t2OQSqxHB7YI68+N0UUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUV9Z/AzxQ3iH4fQ2lxJGbvSX+xkB13GIAGJioA2jb8g658snJOcemUUUUUUUUUUV8YfFTU/7X+KHiG58nytl2bbbu3Z8kCLdnA67M47Zxz1rj6KKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKK9w/Zr1PyvEOuaT5OftNolz5u77vlPt24xznzs5zxt754+j6KKKKKKKKKK+RPjV4cm0L4kajcLYyQWGov9pt5CSyysVUykHJ58wsSvbI4AIrzuiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiu8+DWqrpPxU0Z5bmSCC4d7V9u7EhdCqIwHUGTZ14BAPGM19h0UUUUUUUUUVwfxZ8DN448HPBZxRtq1m/n2ZYqu49Gj3EcBl7ZA3KmSAK+cPF/wu8R+CNGsdU1dbTyLpxGVhm3NDIV3BH4AzgNypYfKeemeLooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooqxYX1xpmo21/ZyeXdWsqTQvtB2upBU4PBwQOtfRdh+0bop8NefqGmXY1pECtbQAeVK+G+ZXJyqZVc5BI3jAfBNR2v7Smjvp073nh++ivhu8mGGZJI34+Xc52lcnIOFbA556VTvf2lof7LjNh4bkGoOkgcT3IMULfwEEDMg7kfJ6AnrVTwx+0Zfy65DF4nsLGLTJPkeayikDwk9HILNuUdwBnnIzjaff7G/s9Ts47ywu4Lu1kzsmgkEiNgkHDDg4II/CrFFFFFRzzw2tvLcXEscMESF5JJGCqigZJJPAAHOa5fRPib4L8Q3H2fTvEFo0+9EWObdA0jMcKEEgUuSRjC56j1FU/i9o39t/C/WolSAzW0QvI3mH3PKIdipwcMUDqP97GQCa+OKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKkgnmtbiK4t5ZIZ4nDxyRsVZGByCCOQQec13EPxj8bweGjoqaxIQXY/bXy91sYEFPMYk4ychvvA4wwAArk9T13WNb8r+1tVvr/wAnPl/a7h5dmcZxuJxnA6egrqPhr8RrzwFriySefc6RLlbmzWUgDdtzIi52+YNo5PUZGRnI+i4PjL4Dm0aLU312OBHcRtBJE5nRyu7BjUE4HTcMrngMa87+Kfxt06/0OTQ/Cc32lb6Jo7u8eFlCRtkGNVcAliM5bGADxknK+AV2k3xU8Vz+Ch4Ue9jFgEWHzEjCS+QqhRDuXA2YHPG48gkgkVxdFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF
FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFf/Z\n"
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# VISUALIZE MASK\n",
    "# Wrap the 0/255 array in an 8-bit grayscale (\"L\") image; the cell output renders it\n",
    "img_mask = Image.fromarray(mask, mode=\"L\")\n",
    "img_mask"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# 3.3 OPTION 3: Use box and point as a prompt"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "outputs": [
    {
     "data": {
      "text/plain": "array([[[234.24837, 267.56625],\n        [220.86275, 262.5494 ],\n        [428.33987, 543.49396]]], dtype=float32)"
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# ENCODE PROMPT (point + box)\n",
    "# The box is passed to the decoder as two extra prompt points: its two\n",
    "# corners, tagged with labels 2 and 3. The single point carries label 0,\n",
    "# which in the SAM prompt convention marks a background (negative) point.\n",
    "input_box = np.array([132, 157, 256, 325]).reshape(2, 2)\n",
    "box_labels = np.array([2, 3])\n",
    "input_point = np.array([[140, 160]])\n",
    "input_label = np.array([0])\n",
    "\n",
    "# Stack the point and the box corners and add a leading batch axis:\n",
    "# coordinates become (1, 3, 2), labels become (1, 3).\n",
    "onnx_coord = np.concatenate([input_point, input_box], axis=0)[None, :, :]\n",
    "onnx_label = np.concatenate([input_label, box_labels], axis=0)[None, :].astype(np.float32)\n",
    "\n",
    "# Rescale (x, y) from original-image pixels to the resized encoder input.\n",
    "# astype() already returns a fresh array, so no extra deepcopy is needed.\n",
    "coords = onnx_coord.astype(np.float64)\n",
    "coords[..., 0] *= resized_width / orig_width\n",
    "coords[..., 1] *= resized_height / orig_height\n",
    "\n",
    "onnx_coord = coords.astype(np.float32)\n",
    "onnx_coord"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "outputs": [],
   "source": [
    "# RUN DECODER TO GET MASK\n",
    "decoder = ort.InferenceSession(\"vit_b_decoder.onnx\")\n",
    "\n",
    "# No previous mask is supplied: zero-filled mask_input with has_mask_input set to 0.\n",
    "onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)\n",
    "onnx_has_mask_input = np.zeros(1, dtype=np.float32)\n",
    "\n",
    "decoder_inputs = {\n",
    "    \"image_embeddings\": embeddings,\n",
    "    \"point_coords\": onnx_coord,\n",
    "    \"point_labels\": onnx_label,\n",
    "    \"mask_input\": onnx_mask_input,\n",
    "    \"has_mask_input\": onnx_has_mask_input,\n",
    "    \"orig_im_size\": np.array([orig_height, orig_width], dtype=np.float32),\n",
    "}\n",
    "\n",
    "# The decoder returns three outputs; only the first (the masks) is used here.\n",
    "masks, _, _ = decoder.run(None, decoder_inputs)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "outputs": [
    {
     "data": {
      "text/plain": "(415, 612)"
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# POSTPROCESS MASK\n",
    "# Take the first mask of the first batch item and binarize it:\n",
    "# values above 0 become 255, everything else becomes 0 (uint8).\n",
    "foreground = masks[0][0] > 0\n",
    "mask = foreground.astype(np.uint8) * 255\n",
    "mask.shape"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "outputs": [
    {
     "data": {
      "text/plain": "<PIL.Image.Image image mode=L size=612x415>",
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAmQAAAGfCAAAAADUx5EnAAAIa0lEQVR4Ae3dgXLTOBRA0bLD//9yNwlQpSlRXOMnP+kdZnbGwY5sHV3stMuUtze/CBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIDAYQI/Dhvp8IHe394SX93h0114wP/Szu3S2Nv1P7+mF8gb2Y1WZdMXdplA2gfSn7y2XuDl+K2HrrBuU80h+Z1s6xPz/drkny6nWoAKF5s+sk3pyCt1q1kju8vmbvMJ5esjnrzRbw8RyBrZ/eRfNHR7VN4fbzuZwAyR9cleJNh/s70jBKaIrNdRb98IQOd4LTBFZJ0P/xp7vcanHzFHZE8r09jpBW24gEkie1LZ58Z8M3bDgp9xyM8zTrrnnJ972jOC95wlMMud7Cwf5z1AQGQHIBqiLyCyvo+9BwiI7ABEQ/QFForMF5f9pT5v70KRnYfozH2BdSJzI+uv9Il714nsRESn7guIrO9j7wECy0TmaXlADUFDLBNZkI9hDxAQ2QGIhugLiKzvY+8BAiI7ANEQfQGR9X3sPUAgaWT+9tgBa5tmiKSRfd9Hlt83G/WOnJHtKWbPe0YpFz9PzsiKL8pq008ZmZvSWpmljGwfsTT3ucW/K2Nke2vZ+7545eJnyBjZ7iVR2W660Ddm/MsL/9ZKxhmFLmH+wZe6k924/SSpdNVl/HP/b3eyG3HGaaVb+2EXlHA1DmjswpdwYsMWNduJ1ntc/hb21MyT2rKRXX4Y9jG3xDxrNe2VLByZzLJUuXRkMsuR2eKR+RckMmS2fGRPfkZjBvsy17B+ZL4AOD3mApF5ZJ5dWb7IIr7xEDHm2Ss30fnzRTYRnkvdJlAkMreybTnEHFUkshg8o24TSBdZ0D0naNhtyNWPShdZ1IKoLEr29bhlIvNN2dcxRB1RJ7IoQeO+FMgWmafayyWb74BskQUK6jcQtzt0ssiE0F2tSXcmi2xSRZfdFagUmdtkN4W4nZUii1M0clegVGRuZd0WwnaWiixM0cBdgVqRuZV1Y4jaWSsy/28pqqPuuMUi61rYGSRQLTIPzKCQesNWi6xnYV+QQLnI3MqCSuoMWy6yjoVdQQIiC4I1bBOoF5nnZVv9QVv1IhsE6zRNQGTNwlaQQMHIPC+DWno6bMHInlrYESQgsiBYwzYBkTULW0ECIguCNWwTEFmzsBUkILIgWMM2gYqR+R5GW/8hWxUjGwLrJE1AZM3CVpCAyIJgDdsERNYsbAUJiCwI1rBNQGTNwlaQgMiCYA3bBETWLGwFCSSLzL8cHrTOpw6bLLIhFkoewtxOUjGyNntbQwRENoS59klEVnv9h8xeZEOYa59EZLXXf8jss0U24Cu/AacYsnTznCRbZPPIudLNAukic5/ZvHbTHJgu
smnkXOhmgXyRRd/KosffTF/nwHyR1bEvM9OEkbnVrFZfwshWIzafjJG5lS3WZcbIFiM2HZFpIFxAZOHETiAyDYQL1IvMj1sJj+rxBCkj8+Xl4zLN/TplZHOTuvpHAZE9inh9uEDByHwoO7yiFwMWjMy/Ef2iicN354zMJ//DF/rMAXNGFizigRkM/DB8ycgeDLwMFqgZmVtZcFafh68Zmc/+nysIfpX1I3b0vebH9QRZJx+85sOHL3on+3Uriy55+GomPWHVyJIux5qXlTUyT7KFessa2ZjPS56XQ1JOfMf4awHX6/3rjn1aiae/b0Ip35VZ+RrT5fp+NXX9cvDjYg/L7GPElIuzykWlVn5/fnX7M/txfyd8foJVFjjDPOZV3pnZbcLtvfPOP0M9G69hYuRWysa5Xg9r8729vb38xhgO/abAz28en+jwh0Dumvu95+53El12wUt5WKmZBa5NPUzna2Z3B1x23r2aeebZr30l5vcv0XyJbKXpZk+rXd/i6g+VLT7btqy5tpZnv89s+cnmauvjagq4f2RWYK4f62pjtMBHZqNP7HwECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAAQIECBAgQIAAgRcC/wPfu23OLIEm+gAAAABJRU5ErkJggg==\n",
      "image/jpeg": "/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/wAALCAGfAmQBAREA/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/9oACAEBAAA/APn+iiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiii
iiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiveLr9naUeBoLmzvJz4mESyzWkzIIWYjLRKR0YZADFipK9g2V8LngmtbiW3uIpIZ4nKSRyKVZGBwQQeQQeMVHRRVixsLzU7yOzsLSe7upM7IYIzI7YBJwo5OACfwrU8S+D9f8AB9xBBr2myWbzoXiJdXVwDg4ZSRkcZGcjI9RWHRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRW54LghuvHXh63uIo5oJdTtkkjkUMrqZVBBB4II4xX3HXjfxt+GE3iS3XxFoNpG+qW6EXcManzLuMAYI7F1APGMsDjJ2qp+ZKK2PDHhjVPF2uQ6RpEHm3EnLM3CRIOrueyjI/MAAkgH638C/DfQvAdmPsMPnalJEI7m/kzvl5ycDJCLn+Ef3VyWIzXln7TM8LXHhq3WWMzoly7xhhuVWMQUkdQCVYA99p9K8DooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooroPAn/JQ/DX/YVtf/Rq19v0UUUUV8ofHzWbPV/iU0Vm/mf2faJZzOCCpkDO7AEE9N4U5wQysMcV5fRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRXQeBP+Sh+Gv+wra/8Ao1a+36KKKKK+CL++uNT1G5v7yTzLq6leaZ9oG52JLHA4GST0qvRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRXuH7O3hH7ZrF34ruV/c2Oba056zMvztw38KNjBBB8zI5Wvo+iiiiivkz4z+A5vCviybUbLTo7fQb9wbYwZ2RybRvQj+AlgzAdMH5fukDzOiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiivrP4DabDY/CuzuImkL39xNcShiMBg5iwvHTbGp5zyT9B6ZRRRRRRXD+J/hJ4O8VSzXN1pn2W+m5a7sm8pyd24sRyjMSTlmUk568DHhHi74F+KfDm6401P7bsRj57SMiZfujmLJJ5J+6W4Uk4ry+iiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiivs/4V6Z/ZHwv8PW3nebvtBc7tu3HnEy7cZPTfjPfGeOldhRRRRRRRRXi/xk+FWl3uh6l4p0e1+zatb7rq6WM4S5TjzGIJAVgAXyv3vmyGLAj5ooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooorY8K+HrjxX4o07Q7Vtkl3KEL4B8tAMu+CRnaoY4zzjA5r7noooooooooor5A+M3hi38L/ABFuobKCC3sbyJLu3ghziMNlWGD9350cgDgAjGOg8/ooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooor0D4Jf8le0L/t4/9J5K+v6KKKKKKKKKK83+M3gV/GXhIXFkM6npe+eBQrMZUK/PEoX+JtqkcHlQONxNfJFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFegfBL/AJK9oX/bx/6TyV9f0UUUUUUUUUUV8YfFHw7/AMIx8RdXsY4vLtZJftNsFh8pPLk+YKg6bVJKZHHyHp0HH0UUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUV6B8Ev+SvaF/28f8ApPJX1/RRRRRRRRRRRXzR+0jY3EfjLSb9o8Ws2n+TG+4
fM6SOWGOvAkT8/Y14vRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRXefBieG2+LegvPLHEheVAzsFBZoXVRz3LEADuSBX2HRRRRRRRRRRRXz/+01/zK3/b3/7RrwCiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiitjwnfW+meMtDv7yTy7W11C3mmfaTtRZFLHA5OAD0r7nooooooooooor5Y/aGvre7+JUcMEm+S00+KGcbSNjlnkA56/K6nj19c15PRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRX1f8C/F3/CR+Bk024bN9o222fj70JB8puFAHAKYyT+7yfvV6hRRRRRRRRRRXxR8R7641D4leI5rqTzJF1CaEHaBhI2MaDj0VVHvjnmuXooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooor3D9mvU/K8Q65pPk5+02iXPm7vu+U+3bjHOfOznPG3vnj6Poooooooooor44+L2jf2J8UNaiVJxDcyi8jeYff80B2KnAyocuo/wB3GSQa4eiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiu8+DWqrpPxU0Z5bmSCC4d7V9u7EhdCqIwHUGTZ14BAPGM19h0UUUUUUUUUV5v8AGbwK/jLwkLiyGdT0vfPAoVmMqFfniUL/ABNtUjg8qBxuJr5s8S+AfFHhC3guNd0mS1gncpHIJEkXcBnBKMQDjJAOM4OOhrm6KKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKkgnmtbiK4t5ZIZ4nDxyRsVZGByCCOQQec19B+Hv2ibOHwkTr9nPPrsH7tVt1CpdfKSHJ6R8gBsA/eBUEZVdCx/aR8NyWcbX+jarBdHO+OARyovJxhiyk8Y/hHpz1rP1P8AaUsxp0X9k+H52vniPmfa5gscMmBjG3JkXOeuwkAdM8R+FP2jVmuBb+K9MjgR3wt3YBiqAlR80bEnA+YlgSegCnrXuGlarY65pdvqemXMdzZ3Cb4pU6MP5gg5BB5BBBwRVyiiiio554bW3luLiWOGCJC8kkjBVRQMkkngADnNcPofxj8Ea7cXMCaxHZPC7BTf4gWVAQN6sxxg54UkNwcqMVT+NFlZ658I768ieCdYPJvbWYThU+8BuVsgNlHcAc53DAJxXyRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRXaQfFXxfa+DovDVvqkkNtE48u5jLLcJGOkQkByEB/HHy52/LXN6nrusa35X9rarfX/k58v7XcPLszjONxOM4HT0FdJ8PPiPqXgDVGmiSS90+RHEtg1wyIWO3516gP8qjJU8ZHfI+h7D43+A7zS/t02qyWRDhGt7i3cyqTuxwgYEYUnKkgZUHBIFcH8U/jbp1/ocmh+E5vtK30TR3d48LKEjbIMaq4BLEZy2MAHjJOV8AroLrxx4lvvC8Hhq51eeTSIdoS3IXopyqlsbmUdlJIGFwPlGOfoooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo
ooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooor/2Q==\n"
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# VISUALIZE MASK\n",
    "# Wrap the mask array as a single-channel (L mode) PIL image; the bare\n",
    "# last expression makes the notebook render it inline.\n",
    "img_mask = Image.fromarray(mask, mode=\"L\")\n",
    "img_mask"
   ],
   "metadata": {
    "collapsed": false
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
